{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <center>Анализ аварий на ЖД транспорте США в 2013 году \n",
    "## <center>и страховых выплат по ним\n",
    "<div style=\"border: 2px solid black; margin-left: 15%; margin-right: 15%\"><img style=\"\" src='https://www.atsb.gov.au/media/2444492/ro2010012_fig1.jpg' /></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Описание набора данных и признаков\n",
    "\n",
    "В этом проекте исследуются данные об инцидентах грузового железнодорожного транспорта США за 2013 год и соответствующие запросы на страховое возмещение ущербра от перевозчиков. Данные взяты с <a href=\"http://www.explore-support.com/help/sample-data-sets\">Cisco Data Explore</a>.\n",
    "\n",
    "Датасет содержит следующие признаки:\n",
    "<table class='desc_table' align='left' width='100%'>\n",
    "<tr><th style=\"text-align: left !important\">Признак</th><th style=\"text-align: left !important\">Описание</th></tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">DEPARTURE CITY</td>\n",
    "<td style=\"text-align: left !important\">Город отправления груза (вагона)</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">DEPARTURE STATE</td>\n",
    "<td style=\"text-align: left !important\">Штат отправления груза (вагона)</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">DEPARTURE CARRIER</td>\n",
    "<td style=\"text-align: left !important\">Перевозчик отправления, отправитель</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">ARRIVAL CITY</td>\n",
    "<td style=\"text-align: left !important\">Город прибытия груза (вагона)</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">ARRIVAL STATE</td>\n",
    "<td style=\"text-align: left !important\">Штат прибытия груза (вагона)</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">ARRIVAL CARRIER</td>\n",
    "<td style=\"text-align: left !important\">Компания-перевозчик прибытия, принимающая сторона</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">RAIL SPEED SPEED</td>\n",
    "<td style=\"text-align: left !important\">Тип скороcти железной дороги</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">RAIL CAR TYPE TYPE</td>\n",
    "<td style=\"text-align: left !important\">Тип вагона</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">RAIL OWNERSHIP OWNERSHIP</td>\n",
    "<td style=\"text-align: left !important\">Тип собственности железной дороги</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">RAIL CARLOAD LOAD</td>\n",
    "<td style=\"text-align: left !important\">Тип груза</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">DEPEARTURE DATE</td>\n",
    "<td style=\"text-align: left !important\">Дата отправления</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">ARRIVAL DATE</td>\n",
    "<td style=\"text-align: left !important\">Дата прибытия</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">CAR VALUE</td>\n",
    "<td style=\"text-align: left !important\">Стоимость вагона, USD</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td style=\"text-align: left !important\"><b>DAMAGED</b></td>\n",
    "<td style=\"text-align: left !important\"><b>Размер ущерба, USD</b></td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">WEIGHT</td>\n",
    "<td style=\"text-align: left !important\">Вес груза</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">FUEL USED</td>\n",
    "<td style=\"text-align: left !important\">Количествто израсходованного топлива</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">PROPER DESTINATION</td>\n",
    "<td style=\"text-align: left !important\">Метка правильности назначения</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\">MILES</td>\n",
    "<td style=\"text-align: left !important\">Пройденный путь</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td style=\"text-align: left !important\"># OF STOPS</td>\n",
    "<td style=\"text-align: left !important\">Количество остановок в пути</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<b>Задача данного пректа</b> - попытаться предсказать размер ущерба, полученного в результате инцидента, а также предоставить другую полезную информацию. \n",
    "\n",
    "<b>Ценность результатов проекта</b> - информация для страховых компаний, позволяющая быть более гибкими в расчете стоимости страховки для тех или иных компаний, грузов, направлений и т. д., а также для самих перевозчиков - для прогноза затрат на перевозку грузов, выбора более безопасных путей, времени и других параметров для транспортировки грузов."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Первичный анализ данных"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Импортируем все нужные библиотеки:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelBinarizer\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold, cross_val_score\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "import xgboost as xgb\n",
    "\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь посмотрим на наши данные:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_file = '../../data/Rail_Insurance_Claims.csv'\n",
    "data = pd.read_csv(data_file, sep=',', parse_dates=['DEPEARTURE DATE','ARRIVAL DATE'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.describe(include=['object'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.describe(include='datetime64')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Видим, что в датасете большинство признаков - категориальные, целевая переменная и кол-во топлива - вещественные, два временных признака и несколько количественных. В данных нет пропусков. Также в данных есть много парных признаков DEPARTURE и ARRIVAL. Отсюда можно сделать следующие промежуточные выводы:\n",
    " * Категориальные признаки понадобится кодировать (OHE, mean target)\n",
    " * Из парных признаков можно извлечь много дополинтельной информации и создать новые признаки, связанные с путем следования, временем в пути и т.д.\n",
    " * В данных на первый взгляд нет пропусков - нужно проверить значения категориальных признаков на смысловые пропуски: значения типа N/A, unknown и т. д.\n",
    " * Из числовых признаков можно создать новые относительные.\n",
    " * Из временных признаков также можно извлечь дополнительную информацию\n",
    " * В данных присутствуют признаки CAR VALUE и DAMAGED - целесообразно предсказывать не сам ущерб, а его процент\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['damaged percent'] = data['DAMAGED'] / data['CAR VALUE']\n",
    "data['damaged percent'].describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Видим, что значения находятся в интервали от 0 до 1, т.е. признак корректный. Также можно заметить, что 75-я квантиль на порядок меньше максимального значения, из чего можно сделать вывод, что целевая переменная имееть сильный дисбаланс."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Посмотрим на распределения категориальных признаков и целевой переменной: построим несколько pivot-таблиц по категориальным признакам."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.pivot_table( ['damaged percent'], ['DEPARTURE STATE'], ['ARRIVAL STATE'],  aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Посмотрим на распределение по штатам отправления:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data['DEPARTURE STATE'].value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Создадим дополнительные временные признаки и посмотрим, возможно, есть перекос по дням или месяцам в ущербе для штата отправления TN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['Dep Month'] = data['DEPEARTURE DATE'].dt.month\n",
    "data['Dep Day'] = data['DEPEARTURE DATE'].dt.day"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data[data['DEPARTURE STATE'] == 'TN'].pivot_table( ['damaged percent'], ['Dep Month'], ['Dep Day'], aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Продолжим с другими категориями:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.pivot_table( ['damaged percent'], ['DEPARTURE CARRIER'], ['ARRIVAL CARRIER'],  aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.pivot_table( ['damaged percent'], ['RAIL CARLOAD LOAD'], ['RAIL OWNERSHIP OWNERSHIP'],  aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.pivot_table(['damaged percent'], ['RAIL CAR TYPE TYPE'], ['RAIL SPEED SPEED'], aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.pivot_table(['damaged percent'], ['PROPER DESTINATION'], aggfunc='mean')) * 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Построим корреляционную матрицу:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_m = data.corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "round((corr_m), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Проверим целевую переменную на нормальность и скошенность:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import shapiro, skewtest, skew\n",
    "\n",
    "print('Normality: {}'.format(shapiro(data['damaged percent'])))\n",
    "print('Skewness: {}'.format(skew(data['damaged percent'])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Проверим, нет ли в категориальных признаках значений, похожих на пропуски:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for c in data.select_dtypes(include=['object']):\n",
    "    print('{}: {}'.format(c, data[c].unique()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Выводы:\n",
    "К предыдущим выводам можно добавить следующее:\n",
    "* Вагоны, следующие из штата TN, терпят гораздо больший ущерб, чем из всех остальных штатов, но судя по распределению кол-ва рейсов из этого штата и средних потерь по датам, не похоже, что это выброс. Возможно, это будет хорошим предиктором.\n",
    "* Перевозчики-отправители CN и  NS страдают немного сильнее, чем остальные, однако в ределах стандартного отклонения\n",
    "* В распределении остальных категорий отностильено таргета ничего необычного не замечено\n",
    "* С целевой переменной немного коррелируют стоимость вагона, вес груза, использованное топливо и кол-во остановок\n",
    "* Целевая переменная не распределена нормально и имеет скошенность с тяжелым левым хвостом, понадобится ее преобразовать\n",
    "* Выбросов и пропусков найти не удалось\n",
    "* Значения переменной PROPER DESTINATION нужно преобразовать в [0, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Первичный визуальный анализ данных"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Для начала дополним данные недостающими временными признаками, преобразуем переменную PROPER DESTINATION:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['Dep DayOfWeek'] = data['DEPEARTURE DATE'].dt.weekday\n",
    "data['Dep weekend'] = data['Dep DayOfWeek'].isin([5,6]).astype('int')\n",
    "\n",
    "data['Arr Month'] = data['ARRIVAL DATE'].dt.month\n",
    "data['Arr Day'] = data['ARRIVAL DATE'].dt.day\n",
    "data['Arr DayOfWeek'] = data['ARRIVAL DATE'].dt.weekday\n",
    "data['Arr weekend'] = data['Arr DayOfWeek'].isin([5,6]).astype('int')\n",
    "\n",
    "data['proper_dest'] = data['PROPER DESTINATION'].map({'Yes': 1, 'No':0})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['duration'] = (data['ARRIVAL DATE'] - data['DEPEARTURE DATE']).dt.days"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Отобразим корреляционную матрицу:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_m = data.corr()\n",
    "plt.figure(figsize=(15, 15))\n",
    "sns.heatmap(np.abs(c_m), annot=True, fmt=\".2f\", linewidths=.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Видна корреляция между днем отправления и днем прибытия, месяцем отправления и месяцем прибытия, что говорит о наличии расписания, корреляция между днем недели и выходным также ясна, как и между использованным топливом и весом груза. Интересна корреляция между правильным направлением и весом, нуждается в доп. изучении. Видно, что добавленные признаки немного коррелируют с целевой переменной. Переменную DAMAGED нужно удалить, она может быть получена из CAR VALUE и damaged percent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.drop(['DAMAGED'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Отобразим плотности распределения числовых величин и колличество значений категориальных:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clmns = data.select_dtypes(exclude=['object','datetime64', 'bool']).columns\n",
    "f, axarr = plt.subplots(ncols=1, nrows=len(clmns), figsize=(15, 40))\n",
    "\n",
    "for c in clmns:\n",
    "    sns.distplot(data[c], ax=axarr[list(clmns).index(c)])\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "categories = data.select_dtypes('object').columns\n",
    "\n",
    "f, axarr = plt.subplots(ncols=1, nrows=len(categories), figsize=(15, 40))\n",
    "\n",
    "i = 0\n",
    "for cat in categories:\n",
    "    g = sns.countplot(x=cat, data=data, ax=axarr[i], palette=\"Blues_d\")\n",
    "    \n",
    "    r = 30\n",
    "    if (cat.endswith('CITY')):\n",
    "        r = 90\n",
    "    g.set_xticklabels(g.get_xticklabels(), rotation=r)\n",
    "    \n",
    "    i += 1\n",
    "plt.tight_layout(h_pad=0.5)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "categories = list(data.select_dtypes('object').columns) + ['# OF STOPS', 'Dep Month', 'Dep Day', 'Dep DayOfWeek', 'Arr Month', 'Arr Day', 'Arr DayOfWeek', 'Dep weekend', 'Arr weekend']\n",
    "\n",
    "f, axarr = plt.subplots(ncols=1, nrows=len(categories), figsize=(15, 80))\n",
    "\n",
    "i = 0\n",
    "for cat in categories:\n",
    "    g = sns.barplot(x=cat, y='damaged percent', data=data, ax=axarr[i])\n",
    "    \n",
    "    r = 30\n",
    "    if (cat.endswith('CITY')):\n",
    "        r = 90\n",
    "    g.set_xticklabels(g.get_xticklabels(), rotation=r)\n",
    "    \n",
    "    i += 1\n",
    "plt.tight_layout(h_pad=0.5)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_tbl = (data.pivot_table(['damaged percent'], ['DEPARTURE STATE'], ['ARRIVAL STATE'], aggfunc='mean')) * 100\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "sns.heatmap(state_tbl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Выводы\n",
    "Видно, что результаты визуального анализа отображают закономерности, выявленные в предыдущей части. Распределения величин не указывает на наличие выбросов. Почти все значения признаков имеют разные средние значения damaged percent, т.е. должны быть учтены при прогнозе."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Инсайты и закономерности"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Часть закономерностей описана выше.\n",
    "Следует также создать признаки, связанные с путем, закодировать некоторые из них OHE, для путей как таковых применить mean target encoding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Выбор метрики и модели"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Т.к. в данных нет выбросов, а решаемая задача - регрессия, можно использовать MSE или R2. Т.к. R2 - это по сути 1 - усредненная MSE, то используем первую для большей наглядности. \n",
    "\n",
    "В качестве модели будем сравнивать линейную регрессию и градиентный бустинг, т.к. они обе подходят для задачи регрессии, а градиентный бустинг хорошо себя зарекомендовал в решениях задач со смешанным типом признаков. При использовании линейной регрессии будем масштабировать признаки."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Предобработка данных и создание новых признаков"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Частично предобработка была выполнена выше для визуализации.\n",
    "\n",
    "Создадим признаки, связанные с путем:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['interstate'] = (data['DEPARTURE STATE'] != data['ARRIVAL STATE']).astype('int')\n",
    "data['intercity'] = (data['DEPARTURE CITY'] != data['ARRIVAL CITY']).astype('int')\n",
    "data['intercarrier'] = (data['DEPARTURE CARRIER'] != data['ARRIVAL CARRIER']).astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['city route'] = data['DEPARTURE CITY'] + data['ARRIVAL CITY']\n",
    "data['state route'] = data['DEPARTURE STATE'] + data['ARRIVAL STATE']\n",
    "data['carrier route'] = data['DEPARTURE CARRIER'] + data['ARRIVAL CARRIER']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Создадим относительный признак расход доплива: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['FUEL PER MILE'] = (data['FUEL USED']/data['MILES'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Напишем функцию для кодирования средним:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_target_enc(train_df, y_train, valid_df, cat_features, skf):\n",
    "    import warnings\n",
    "    warnings.filterwarnings('ignore')\n",
    "    target_name = y_train.name\n",
    "    \n",
    "    glob_mean = y_train.mean()\n",
    "    train_df = pd.concat([train_df, pd.Series(y_train, name='y')], axis=1)\n",
    "    new_train_df = train_df.copy()  \n",
    "\n",
    "    for col in cat_features:\n",
    "        new_train_df[col + '_mean_' + target_name] = [glob_mean for _ in range(new_train_df.shape[0])]\n",
    "\n",
    "    for train_idx, valid_idx in skf.split(train_df, y_train):\n",
    "        train_df_cv, valid_df_cv = train_df.iloc[train_idx, :], train_df.iloc[valid_idx, :]\n",
    "\n",
    "        for col in cat_features:\n",
    "            \n",
    "            means = valid_df_cv[col].map(train_df_cv.groupby(col)['y'].mean())\n",
    "            valid_df_cv[col + '_mean_' + target_name] = means.fillna(glob_mean)\n",
    "            \n",
    "        new_train_df.iloc[valid_idx] = valid_df_cv\n",
    "    \n",
    "    new_train_df.drop(['y'], axis=1, inplace=True)\n",
    "    \n",
    "    for col in cat_features:\n",
    "        means = valid_df[col].map(train_df.groupby(col)['y'].mean())\n",
    "        valid_df[col + '_mean_' + target_name] = means.fillna(glob_mean)\n",
    "        \n",
    "#     valid_df.drop(cat_features, axis=1, inplace=True)\n",
    "    \n",
    "    return new_train_df, valid_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Применим OHE преобразование к категориальным признакам, за исключением путей:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_ohe = pd.get_dummies(data, columns=['DEPARTURE CITY', 'DEPARTURE STATE', 'DEPARTURE CARRIER',\n",
    "       'ARRIVAL CITY', 'ARRIVAL STATE', 'ARRIVAL CARRIER', 'RAIL SPEED SPEED',\n",
    "       'RAIL CAR TYPE TYPE', 'RAIL OWNERSHIP OWNERSHIP', 'RAIL CARLOAD LOAD'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Применим mean target encoding к путям"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_ohe, _ = mean_target_enc(data_ohe, (data['damaged percent']*10000).astype('int'), data_ohe[-1:], ['city route', 'state route', 'carrier route'], StratifiedKFold(5, shuffle=True, random_state=17))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Удалим ненужные фичи и выделим целевую переменную:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = data['damaged percent']\n",
    "data_ohe.drop(['DEPEARTURE DATE', 'ARRIVAL DATE', 'FUEL USED', 'city route', 'state route', 'carrier route', 'PROPER DESTINATION'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_ohe.drop(['damaged percent'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import scipy.stats as stats\n",
    "\n",
    "stats.probplot(y, dist=\"norm\", plot=plt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "stats.probplot(np.log(y), dist=\"norm\", plot=plt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "stats.probplot(StandardScaler().fit_transform(y.values.reshape(-1,1).astype(np.float64)).flatten(), dist=\"norm\", plot=plt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.log(y)\n",
    "y.hist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Выделим тренировочную, валидационную и тестовые выборки. Т.к. данные у нас сбалансированы, выберем случайный способ. Отмасштабируем признаки."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "st = StandardScaler()\n",
    "\n",
    "X_train, X_val, y_train, y_val = train_test_split(data_ohe, y, test_size=0.3, random_state=1)\n",
    "X_train_st = st.fit_transform(X_train)\n",
    "X_val_st = st.transform(X_val)\n",
    "\n",
    "X_val_st, X_test_st, y_val, y_test = train_test_split(X_val_st, y_val, test_size=0.3333, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = LinearRegression(n_jobs=-1)\n",
    "\n",
    "cv_sc = cross_val_score(lr, X_train, y_train,  cv=5, n_jobs=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv_sc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr.fit(X_train, y_train)\n",
    "lr_pred_val = lr.predict(X_val_st)\n",
    "\n",
    "r2_score(y_val, lr_pred_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_test_pred = lr.predict((X_test_st))\n",
    "print(r2_score(y_test, lr_test_pred))\n",
    "np.sqrt(mean_squared_error(y_test, lr_test_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(15, 10))\n",
    "plt.plot(np.exp(y_test).values[-200:], 'b')\n",
    "plt.plot(np.exp(lr_test_pred)[-200:], 'g')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Видим, что линейная регрессия не сработала. Посмотрим на xboost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "dtrain = xgb.DMatrix(X_train_st, label=np.sqrt(y_train))\n",
    "dtest = xgb.DMatrix(X_val_st)\n",
    "\n",
    "params = {\n",
    "    'objective':'reg:linear',\n",
    "    'max_depth':5,\n",
    "    'silent':1,\n",
    "    'nthread': 8,\n",
    "#     'booster': 'dart',\n",
    "#     'eta':0.5,\n",
    "#     'gamma': 0.1,\n",
    "#     'lambda': 20,\n",
    "#     'alpha': 0.5\n",
    "}\n",
    "\n",
    "num_rounds = 100\n",
    "xgb_ = xgb.train(params, dtrain, num_rounds)\n",
    "\n",
    "xgb__pred = xgb_.predict(dtest)\n",
    "\n",
    "r2_score(y_val, (xgb__pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "np.sqrt(mean_squared_error(np.sqrt(y_val), xgb__pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "np.sqrt(mean_squared_error(y_val, np.exp(xgb__pred)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xgb_test_pred = xgb_.predict(xgb.DMatrix(X_test_st))\n",
    "print(r2_score(y_test, xgb_test_pred))\n",
    "np.sqrt(mean_squared_error(y_test, xgb_test_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(15, 10))\n",
    "plt.plot(np.exp(y_test).values[-200:], 'b')\n",
    "plt.plot(np.exp(xgb_test_pred)[-200:], 'g')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "А xgboost справляется неплохо, но предсказания получились смещенными вверх."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Выводы:\n",
    "\n",
    "Удалось проанализировать датасет с вполне приемлемыми результатами. Линейная регрессия плохо показала себя в этой задаче, а градиентный бустинг - довольно хорошо. \n",
    "\n",
    "Для улучшения результатов можно выделить больше признаков или закодировать средним больше категориальных фич. Также можно преобразовать задачу в бинарную классификацию, определяя обычные и большие убытки. \n",
    "\n",
    "Используя полученную модель, можно довольно хорошо предсказывать крупные потери при траспортировке, что будет полезно как страховым компаниям, так и перевозчикам."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
